#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import base64
import decimal
import datetime
import json
import struct
from array import array
from decimal import Decimal
from typing import Any, Callable, Dict, List, Tuple
from pyspark.errors import (
    PySparkNotImplementedError,
    PySparkValueError,
)
from zoneinfo import ZoneInfo


class VariantUtils:
    """
    A utility class for VariantVal.

    Adapted from library at: org.apache.spark.types.variant.VariantUtil
    """

    BASIC_TYPE_BITS = 2
    BASIC_TYPE_MASK = 0x3
    TYPE_INFO_MASK = 0x3F
    # The inclusive maximum value of the type info value. It is the size limit of `SHORT_STR`.
    MAX_SHORT_STR_SIZE = 0x3F

    # Below are all possible basic type values.
    # Primitive value. The type info value must be one of the values in the below section.
    PRIMITIVE = 0
    # Short string value. The type info value is the string size, which must be in `[0,
    # MAX_SHORT_STR_SIZE]`.
    # The string content bytes directly follow the header byte.
    SHORT_STR = 1
    # Object value. The content contains a size, a list of field ids, a list of field offsets, and
    # the actual field data. The length of the id list is `size`, while the length of the offset
    # list is `size + 1`, where the last offset represents the total size of the field data. The
    # fields in an object must be sorted by the field name in alphabetical order. Duplicate field
    # names in one object are not allowed.
    # We use 5 bits in the type info to specify the integer type of the object header: it should
    # be 0_b4_b3b2_b1b0 (MSB is 0), where:
    # - b4 specifies the type of size. When it is 0/1, `size` is a little-endian 1/4-byte
    # unsigned integer.
    # - b3b2/b1b0 specifies the integer type of id and offset. When the 2 bits are 0/1/2/3, the
    # list contains 1/2/3/4-byte little-endian unsigned integers.
    OBJECT = 2
    # Array value. The content contains a size, a list of field offsets, and the actual element
    # data. It is similar to an object without the id list. The length of the offset list
    # is `size + 1`, where the last offset represents the total size of the element data.
    # Its type info should be: 000_b2_b1b0:
    # - b2 specifies the type of size.
    # - b1b0 specifies the integer type of offset.
    ARRAY = 3

    # Below are all possible type info values for `PRIMITIVE`.
    # JSON Null value. Empty content.
    NULL = 0
    # True value. Empty content.
    TRUE = 1
    # False value. Empty content.
    FALSE = 2
    # 1-byte little-endian signed integer.
    INT1 = 3
    # 2-byte little-endian signed integer.
    INT2 = 4
    # 4-byte little-endian signed integer.
    INT4 = 5
    # 8-byte little-endian signed integer.
    INT8 = 6
    # 8-byte IEEE double.
    DOUBLE = 7
    # 4-byte decimal. Content is 1-byte scale + 4-byte little-endian signed integer.
    DECIMAL4 = 8
    # 8-byte decimal. Content is 1-byte scale + 8-byte little-endian signed integer.
    DECIMAL8 = 9
    # 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
    DECIMAL16 = 10
    # Date value. Content is 4-byte little-endian signed integer that represents the number of days
    # from the Unix epoch.
    DATE = 11
    # Timestamp value. Content is 8-byte little-endian signed integer that represents the number of
    # microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. This is a
    # timezone-aware field; when read into a Python datetime object, it defaults to the UTC
    # timezone.
    TIMESTAMP = 12
    # Timestamp_ntz value. It has the same content as `TIMESTAMP` but should always be interpreted
    # as if the local time zone is UTC.
    TIMESTAMP_NTZ = 13
    # 4-byte IEEE float.
    FLOAT = 14
    # Binary value. The content is (4-byte little-endian unsigned integer representing the binary
    # size) + (size bytes of binary content).
    BINARY = 15
    # Long string value. The content is (4-byte little-endian unsigned integer representing the
    # string size) + (size bytes of string content).
    LONG_STR = 16
    # Year-month interval value. The content is one byte representing the start and end field
    # values (1 bit each, starting at the least significant bit) followed by a 4-byte
    # little-endian signed integer holding the number of months.
    YEAR_MONTH_INTERVAL = 19
    # Day-time interval value. The content is one byte representing the start and end field
    # values (2 bits each, starting at the least significant bit) followed by an 8-byte
    # little-endian signed integer holding the number of microseconds.
    DAY_TIME_INTERVAL = 20

    U32_SIZE = 4

    EPOCH = datetime.datetime(
        year=1970, month=1, day=1, hour=0, minute=0, second=0, tzinfo=datetime.timezone.utc
    )
    EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)

    MAX_DECIMAL4_PRECISION = 9
    MAX_DECIMAL4_VALUE = 10**MAX_DECIMAL4_PRECISION
    MAX_DECIMAL8_PRECISION = 18
    MAX_DECIMAL8_VALUE = 10**MAX_DECIMAL8_PRECISION
    MAX_DECIMAL16_PRECISION = 38
    MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION

    # There is no PySpark equivalent of the SQL year-month interval type. This class acts as a
    # placeholder for that type.
    class _PlaceholderYearMonthIntervalInternalType:
        pass

    @classmethod
    def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
        """
        Convert the VariantVal to a JSON string. The `zone_id` parameter denotes the time zone
        that timestamp fields should be parsed in. It defaults to "UTC". The list of valid zone
        IDs can be found by importing the `zoneinfo` module and running
        `zoneinfo.available_timezones()`.
        :return: JSON string
        """
        return cls._to_json(value, metadata, 0, zone_id)
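
    # A hand-worked usage sketch (the bytes are encoded by hand from the format constants
    # above, not taken from Spark output): b"\x09hi" is a SHORT_STR header byte (basic type 1,
    # length 2) followed by the UTF-8 bytes of "hi", and b"\x01\x00\x00" is a version-1
    # metadata header with an empty dictionary, so:
    #
    #   VariantUtils.to_json(b"\x09hi", b"\x01\x00\x00")   # -> '"hi"'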

    @classmethod
    def to_python(cls, value: bytes, metadata: bytes) -> Any:
        """
        Convert the VariantVal to a nested Python object of Python data types.
        :return: Python representation of the Variant nested structure
        """
        return cls._to_python(value, metadata, 0)
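
    # A hand-worked sketch: b"\x0c\x05" is a PRIMITIVE header byte with type info INT1
    # (0x0c == 3 << 2) followed by the single byte 5; the metadata is unused for scalars, so
    # any valid metadata such as b"\x01\x00\x00" works:
    #
    #   VariantUtils.to_python(b"\x0c\x05", b"\x01\x00\x00")   # -> 5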

    @classmethod
    def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int:
        cls._check_index(pos, len(data))
        cls._check_index(pos + num_bytes - 1, len(data))
        return int.from_bytes(data[pos : pos + num_bytes], byteorder="little", signed=signed)

    @classmethod
    def _check_index(cls, pos: int, length: int) -> None:
        if pos < 0 or pos >= length:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
        """
        Returns the (basic_type, type_info) pair from the given position in the value.
        """
        basic_type = value[pos] & VariantUtils.BASIC_TYPE_MASK
        type_info = (value[pos] >> VariantUtils.BASIC_TYPE_BITS) & VariantUtils.TYPE_INFO_MASK
        return (basic_type, type_info)
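
    # For example, a header byte of 0x0c decomposes as basic_type = 0x0c & 0x3 = 0 (PRIMITIVE)
    # and type_info = (0x0c >> 2) & 0x3f = 3 (INT1).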

    @classmethod
    def _get_day_time_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int]:
        """
        Returns the (start_field, end_field) pair for a variant representing a day-time interval
        value stored at a given position in the value.
        """
        cls._check_index(pos, len(value))
        start_field = value[pos] & 0x3
        end_field = (value[pos] >> 2) & 0x3
        if end_field < start_field:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        return (start_field, end_field)

    @classmethod
    def _get_year_month_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int]:
        """
        Returns the (start_field, end_field) pair for a variant representing a year-month interval
        value stored at a given position in the value.
        """
        cls._check_index(pos, len(value))
        start_field = value[pos] & 0x1
        end_field = (value[pos] >> 1) & 0x1
        if end_field < start_field:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        return (start_field, end_field)

    @classmethod
    def _get_metadata_key(cls, metadata: bytes, id: int) -> str:
        """
        Returns the key string from the dictionary in the metadata, corresponding to `id`.
        """
        cls._check_index(0, len(metadata))
        offset_size = ((metadata[0] >> 6) & 0x3) + 1
        dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
        if id >= dict_size:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        string_start = 1 + (dict_size + 2) * offset_size
        offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
        next_offset = cls._read_long(
            metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
        )
        if offset > next_offset:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        cls._check_index(string_start + next_offset - 1, len(metadata))
        return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")
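
    # A hand-worked sketch of the metadata layout: b"\x01\x01\x00\x01a" is a version-1 header
    # with 1-byte offsets, a dictionary of size 1, offsets [0, 1], and the string bytes "a",
    # so:
    #
    #   VariantUtils._get_metadata_key(b"\x01\x01\x00\x01a", 0)   # -> "a"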

    @classmethod
    def _get_boolean(cls, value: bytes, pos: int) -> bool:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE or (
            type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
        ):
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        return type_info == VariantUtils.TRUE

    @classmethod
    def _get_long(cls, value: bytes, pos: int) -> int:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        if type_info == VariantUtils.INT1:
            return cls._read_long(value, pos + 1, 1, signed=True)
        elif type_info == VariantUtils.INT2:
            return cls._read_long(value, pos + 1, 2, signed=True)
        elif type_info == VariantUtils.INT4 or type_info == VariantUtils.DATE:
            return cls._read_long(value, pos + 1, 4, signed=True)
        elif type_info == VariantUtils.INT8:
            return cls._read_long(value, pos + 1, 8, signed=True)
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_date(cls, value: bytes, pos: int) -> datetime.date:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        if type_info == VariantUtils.DATE:
            days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
            return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        if type_info == VariantUtils.TIMESTAMP_NTZ:
            microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
            return VariantUtils.EPOCH_NTZ + datetime.timedelta(
                microseconds=microseconds_since_epoch
            )
        if type_info == VariantUtils.TIMESTAMP:
            microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
            return (
                VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
            ).astimezone(ZoneInfo(zone_id))
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_yminterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
        """
        Returns the (months, start_field, end_field) tuple from a year-month interval value at a
        given position in a variant.
        """
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        if type_info == VariantUtils.YEAR_MONTH_INTERVAL:
            months = cls._read_long(value, pos + 2, 4, signed=True)
            start_field, end_field = cls._get_year_month_interval_fields(value, pos + 1)
            return (months, start_field, end_field)
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_dtinterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
        """
        Returns the (micros, start_field, end_field) tuple from a day-time interval value at a given
        position in a variant.
        """
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        if type_info == VariantUtils.DAY_TIME_INTERVAL:
            micros = cls._read_long(value, pos + 2, 8, signed=True)
            start_field, end_field = cls._get_day_time_interval_fields(value, pos + 1)
            return (micros, start_field, end_field)
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_string(cls, value: bytes, pos: int) -> str:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type == VariantUtils.SHORT_STR or (
            basic_type == VariantUtils.PRIMITIVE and type_info == VariantUtils.LONG_STR
        ):
            start = 0
            length = 0
            if basic_type == VariantUtils.SHORT_STR:
                start = pos + 1
                length = type_info
            else:
                start = pos + 1 + VariantUtils.U32_SIZE
                length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
            cls._check_index(start + length - 1, len(value))
            return value[start : start + length].decode("utf-8")
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_double(cls, value: bytes, pos: int) -> float:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        if type_info == VariantUtils.FLOAT:
            cls._check_index(pos + 4, len(value))
            return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
        elif type_info == VariantUtils.DOUBLE:
            cls._check_index(pos + 8, len(value))
            return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int) -> None:
        # max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant
        # computation.
        if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        cls._check_index(pos + 1, len(value))
        scale = value[pos + 1]
        unscaled = 0
        if type_info == VariantUtils.DECIMAL4:
            unscaled = cls._read_long(value, pos + 2, 4, signed=True)
            cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL4_VALUE, cls.MAX_DECIMAL4_PRECISION)
        elif type_info == VariantUtils.DECIMAL8:
            unscaled = cls._read_long(value, pos + 2, 8, signed=True)
            cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL8_VALUE, cls.MAX_DECIMAL8_PRECISION)
        elif type_info == VariantUtils.DECIMAL16:
            cls._check_index(pos + 17, len(value))
            unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
            cls._check_decimal(
                unscaled, scale, cls.MAX_DECIMAL16_VALUE, cls.MAX_DECIMAL16_PRECISION
            )
        else:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
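
    # For example, a DECIMAL4 payload with scale byte 2 and little-endian unscaled value 1234
    # reconstructs as Decimal(1234) * Decimal(10) ** -2 == Decimal("12.34").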

    @classmethod
    def _get_binary(cls, value: bytes, pos: int) -> bytes:
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        start = pos + 1 + VariantUtils.U32_SIZE
        length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
        cls._check_index(start + length - 1, len(value))
        return bytes(value[start : start + length])

    @classmethod
    def _get_type(cls, value: bytes, pos: int) -> Any:
        """
        Returns the Python type of the Variant at the given position.
        """
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type == VariantUtils.SHORT_STR:
            return str
        elif basic_type == VariantUtils.OBJECT:
            return dict
        elif basic_type == VariantUtils.ARRAY:
            return array
        elif type_info == VariantUtils.NULL:
            return type(None)
        elif type_info == VariantUtils.TRUE or type_info == VariantUtils.FALSE:
            return bool
        elif (
            type_info == VariantUtils.INT1
            or type_info == VariantUtils.INT2
            or type_info == VariantUtils.INT4
            or type_info == VariantUtils.INT8
        ):
            return int
        elif type_info == VariantUtils.DOUBLE or type_info == VariantUtils.FLOAT:
            return float
        elif (
            type_info == VariantUtils.DECIMAL4
            or type_info == VariantUtils.DECIMAL8
            or type_info == VariantUtils.DECIMAL16
        ):
            return decimal.Decimal
        elif type_info == VariantUtils.BINARY:
            return bytes
        elif type_info == VariantUtils.DATE:
            return datetime.date
        elif type_info == VariantUtils.TIMESTAMP or type_info == VariantUtils.TIMESTAMP_NTZ:
            return datetime.datetime
        elif type_info == VariantUtils.LONG_STR:
            return str
        elif type_info == VariantUtils.DAY_TIME_INTERVAL:
            return datetime.timedelta
        elif type_info == VariantUtils.YEAR_MONTH_INTERVAL:
            return cls._PlaceholderYearMonthIntervalInternalType
        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _to_year_month_interval_ansi_string(
        cls, months: int, start_field: int, end_field: int
    ) -> str:
        """
        Converts a number of months representing a year-month interval with the given start and
        end fields to its ANSI SQL string representation.
        """
        YEAR = 0
        MONTHS_PER_YEAR = 12
        sign = ""
        abs_months = months
        if months < 0:
            sign = "-"
            abs_months = -abs_months
        year = sign + str(abs_months // MONTHS_PER_YEAR)
        year_and_month = year + "-" + str(abs_months % MONTHS_PER_YEAR)
        format_builder = ["INTERVAL '"]
        if start_field == end_field:
            if start_field == YEAR:
                format_builder.append(year + "' YEAR")
            else:
                format_builder.append(str(months) + "' MONTH")
        else:
            format_builder.append(year_and_month + "' YEAR TO MONTH")
        return "".join(format_builder)

    @classmethod
    def _to_day_time_interval_ansi_string(
        cls, micros: int, start_field: int, end_field: int
    ) -> str:
        """
        Converts a number of microseconds representing a day-time interval with the given start
        and end fields to its ANSI SQL string representation.
        """
        DAY = 0
        HOUR = 1
        MINUTE = 2
        SECOND = 3
        MIN_LONG_VALUE = -9223372036854775808
        MAX_LONG_VALUE = 9223372036854775807
        MICROS_PER_SECOND = 1000 * 1000
        MICROS_PER_MINUTE = MICROS_PER_SECOND * 60
        MICROS_PER_HOUR = MICROS_PER_MINUTE * 60
        MICROS_PER_DAY = MICROS_PER_HOUR * 24
        MAX_SECOND = MAX_LONG_VALUE // MICROS_PER_SECOND
        MAX_MINUTE = MAX_LONG_VALUE // MICROS_PER_MINUTE
        MAX_HOUR = MAX_LONG_VALUE // MICROS_PER_HOUR
        MAX_DAY = MAX_LONG_VALUE // MICROS_PER_DAY

        def field_to_string(field: int) -> str:
            if field == DAY:
                return "DAY"
            elif field == HOUR:
                return "HOUR"
            elif field == MINUTE:
                return "MINUTE"
            elif field == SECOND:
                return "SECOND"
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

        if end_field < start_field:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        sign = ""
        rest = micros
        from_str = field_to_string(start_field).upper()
        to_str = field_to_string(end_field).upper()
        prefix = "INTERVAL '"
        postfix = f"' {from_str}" if (start_field == end_field) else f"' {from_str} TO {to_str}"
        if micros < 0:
            if micros == MIN_LONG_VALUE:
                # Special handling of the minimum `Long` value, because negating it overflows
                # `Long`.
                # seconds = 106751991 * (24 * 60 * 60) + 4 * 60 * 60 + 54 = 9223372036854
                # microseconds = -9223372036854000000 - 775808 == MIN_LONG_VALUE
                base_str = "-106751991 04:00:54.775808000"
                first_str = "-" + (
                    str(MAX_DAY)
                    if (start_field == DAY)
                    else (
                        str(MAX_HOUR)
                        if (start_field == HOUR)
                        else (
                            str(MAX_MINUTE)
                            if (start_field == MINUTE)
                            else str(MAX_SECOND) + ".775808"
                        )
                    )
                )
                if start_field == end_field:
                    return prefix + first_str + postfix
                else:
                    substr_start = (
                        10 if (start_field == DAY) else (13 if (start_field == HOUR) else 16)
                    )
                    substr_end = (
                        13 if (end_field == HOUR) else (16 if (end_field == MINUTE) else 26)
                    )
                    return prefix + first_str + base_str[substr_start:substr_end] + postfix
            else:
                sign = "-"
                rest = -rest
        format_builder = [sign]
        format_args = []
        if start_field == DAY:
            format_builder.append(str(rest // MICROS_PER_DAY))
            rest %= MICROS_PER_DAY
        elif start_field == HOUR:
            format_builder.append("%02d")
            format_args.append(rest // MICROS_PER_HOUR)
            rest %= MICROS_PER_HOUR
        elif start_field == MINUTE:
            format_builder.append("%02d")
            format_args.append(rest // MICROS_PER_MINUTE)
            rest %= MICROS_PER_MINUTE
        elif start_field == SECOND:
            lead_zero = "0" if (rest < 10 * MICROS_PER_SECOND) else ""
            format_builder.append(
                lead_zero + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
            )

        if start_field < HOUR and HOUR <= end_field:
            format_builder.append(" %02d")
            format_args.append(rest // MICROS_PER_HOUR)
            rest %= MICROS_PER_HOUR
        if start_field < MINUTE and MINUTE <= end_field:
            format_builder.append(":%02d")
            format_args.append(rest // MICROS_PER_MINUTE)
            rest %= MICROS_PER_MINUTE
        if start_field < SECOND and SECOND <= end_field:
            lead_zero = "0" if (rest < 10 * MICROS_PER_SECOND) else ""
            format_builder.append(
                ":" + lead_zero + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
            )
        return prefix + ("".join(format_builder) % tuple(format_args)) + postfix

    @classmethod
    def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
        variant_type = cls._get_type(value, pos)
        if variant_type == dict:

            def handle_object(key_value_pos_list: List[Tuple[str, int]]) -> str:
                key_value_list = [
                    json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos, zone_id)
                    for (key, value_pos) in key_value_pos_list
                ]
                return "{" + ",".join(key_value_list) + "}"

            return cls._handle_object(value, metadata, pos, handle_object)
        elif variant_type == array:

            def handle_array(value_pos_list: List[int]) -> str:
                value_list = [
                    cls._to_json(value, metadata, value_pos, zone_id)
                    for value_pos in value_pos_list
                ]
                return "[" + ",".join(value_list) + "]"

            return cls._handle_array(value, pos, handle_array)
        elif variant_type == datetime.timedelta:
            micros, start_field, end_field = cls._get_dtinterval_info(value, pos)
            return '"' + cls._to_day_time_interval_ansi_string(micros, start_field, end_field) + '"'
        elif variant_type == cls._PlaceholderYearMonthIntervalInternalType:
            months, start_field, end_field = cls._get_yminterval_info(value, pos)
            return (
                '"' + cls._to_year_month_interval_ansi_string(months, start_field, end_field) + '"'
            )
        else:
            scalar = cls._get_scalar(variant_type, value, metadata, pos, zone_id)
            if scalar is None:
                return "null"
            if type(scalar) == bool:
                return "true" if scalar else "false"
            if type(scalar) == str:
                return json.dumps(scalar)
            if type(scalar) == bytes:
                # base64-encode binary content; decoding merely converts the ASCII bytes to str
                return '"' + base64.b64encode(scalar).decode("utf-8") + '"'
            if type(scalar) == datetime.date or type(scalar) == datetime.datetime:
                return '"' + str(scalar) + '"'
            return str(scalar)

    @classmethod
    def _to_python(cls, value: bytes, metadata: bytes, pos: int) -> Any:
        variant_type = cls._get_type(value, pos)
        if variant_type == dict:

            def handle_object(key_value_pos_list: List[Tuple[str, int]]) -> Dict[str, Any]:
                key_value_list = [
                    (key, cls._to_python(value, metadata, value_pos))
                    for (key, value_pos) in key_value_pos_list
                ]
                return dict(key_value_list)

            return cls._handle_object(value, metadata, pos, handle_object)
        elif variant_type == array:

            def handle_array(value_pos_list: List[int]) -> List[Any]:
                value_list = [
                    cls._to_python(value, metadata, value_pos) for value_pos in value_pos_list
                ]
                return value_list

            return cls._handle_array(value, pos, handle_array)
        elif variant_type == datetime.timedelta:
            # day-time intervals map directly onto timedelta via their microsecond count
            return datetime.timedelta(microseconds=cls._get_dtinterval_info(value, pos)[0])
        elif variant_type == cls._PlaceholderYearMonthIntervalInternalType:
            raise PySparkNotImplementedError(
                errorClass="NOT_IMPLEMENTED",
                messageParameters={"feature": "VariantUtils.YEAR_MONTH_INTERVAL"},
            )
        else:
            return cls._get_scalar(variant_type, value, metadata, pos, zone_id="UTC")

    @classmethod
    def _get_scalar(
        cls, variant_type: Any, value: bytes, metadata: bytes, pos: int, zone_id: str
    ) -> Any:
        if isinstance(None, variant_type):
            return None
        elif variant_type == bool:
            return cls._get_boolean(value, pos)
        elif variant_type == int:
            return cls._get_long(value, pos)
        elif variant_type == str:
            return cls._get_string(value, pos)
        elif variant_type == float:
            return cls._get_double(value, pos)
        elif variant_type == decimal.Decimal:
            return cls._get_decimal(value, pos)
        elif variant_type == bytes:
            return cls._get_binary(value, pos)
        elif variant_type == datetime.date:
            return cls._get_date(value, pos)
        elif variant_type == datetime.datetime:
            return cls._get_timestamp(value, pos, zone_id)
        else:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

    @classmethod
    def _handle_object(
        cls, value: bytes, metadata: bytes, pos: int, func: Callable[[List[Tuple[str, int]]], Any]
    ) -> Any:
        """
        Parses the variant object at position `pos`.
        Calls `func` with a list of (key, value position) pairs of the object.
        """
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.OBJECT:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        large_size = ((type_info >> 4) & 0x1) != 0
        size_bytes = VariantUtils.U32_SIZE if large_size else 1
        num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
        id_size = ((type_info >> 2) & 0x3) + 1
        offset_size = ((type_info) & 0x3) + 1
        id_start = pos + 1 + size_bytes
        offset_start = id_start + num_fields * id_size
        data_start = offset_start + (num_fields + 1) * offset_size

        key_value_pos_list = []
        for i in range(num_fields):
            id = cls._read_long(value, id_start + id_size * i, id_size, signed=False)
            offset = cls._read_long(
                value, offset_start + offset_size * i, offset_size, signed=False
            )
            value_pos = data_start + offset
            key_value_pos_list.append((cls._get_metadata_key(metadata, id), value_pos))
        return func(key_value_pos_list)
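
    # A hand-worked sketch: with metadata b"\x01\x01\x00\x01a" (dictionary ["a"]), the value
    # bytes([0x02, 0x01, 0x00, 0x00, 0x02, 0x0C, 0x05]) encodes an OBJECT header with 1-byte
    # size/id/offset fields, one field (id 0, offset 0), offsets [0, 2], and the field data
    # b"\x0C\x05" (INT1 value 5), so:
    #
    #   VariantUtils.to_json(bytes([0x02, 0x01, 0x00, 0x00, 0x02, 0x0C, 0x05]),
    #                        b"\x01\x01\x00\x01a")   # -> '{"a":5}'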

    @classmethod
    def _handle_array(cls, value: bytes, pos: int, func: Callable[[List[int]], Any]) -> Any:
        """
        Parses the variant array at position `pos`.
        Calls `func` with a list of element positions of the array.
        """
        cls._check_index(pos, len(value))
        basic_type, type_info = cls._get_type_info(value, pos)
        if basic_type != VariantUtils.ARRAY:
            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
        large_size = ((type_info >> 2) & 0x1) != 0
        size_bytes = VariantUtils.U32_SIZE if large_size else 1
        num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
        offset_size = (type_info & 0x3) + 1
        offset_start = pos + 1 + size_bytes
        data_start = offset_start + (num_fields + 1) * offset_size

        value_pos_list = []
        for i in range(num_fields):
            offset = cls._read_long(
                value, offset_start + offset_size * i, offset_size, signed=False
            )
            element_pos = data_start + offset
            value_pos_list.append(element_pos)
        return func(value_pos_list)
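
    # A hand-worked sketch: bytes([0x03, 0x02, 0x00, 0x02, 0x03, 0x0C, 0x01, 0x04]) encodes an
    # ARRAY header with 1-byte size/offset fields, two elements, offsets [0, 2, 3], and the
    # element data b"\x0C\x01" (INT1 value 1) followed by b"\x04" (TRUE), so:
    #
    #   VariantUtils.to_json(bytes([0x03, 0x02, 0x00, 0x02, 0x03, 0x0C, 0x01, 0x04]),
    #                        b"\x01\x00\x00")   # -> '[1,true]'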
